Prior

Author

Raymond

EDA

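I use the tail-base, mid-center and nose coordinates to build a prior distribution for each of the three pairwise body-part distances: tail-base to nose, tail-base to mid-center, and nose to mid-center. In the histograms below, both the distances from the tracked data and the manually labelled distances look roughly Gaussian once obvious outliers (tracking or labelling mistakes) are discarded. It therefore seems reasonable to model each distance as a Gaussian. Its mean and standard deviation are plug-in estimates (the sample mean and sample standard deviation) computed from the manually labelled data. Note that sd() uses the n - 1 denominator rather than the n of the MLE, but the difference is negligible at this sample size.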
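Written out, the prior for each distance type k (TN, TM, NM) and the per-distance cost used later by the tracker's nll() helper are given below. The helper drops the constant 1/2 log 2π.

```latex
\begin{aligned}
d_k &\sim \mathcal{N}\left(\hat\mu_k,\ \hat\sigma_k^2\right), \qquad k \in \{\mathrm{TN}, \mathrm{TM}, \mathrm{NM}\} \\
\hat\mu_k &= \frac{1}{n}\sum_{i=1}^{n} d_{k,i}, \qquad
\hat\sigma_k = \sqrt{\frac{1}{n-1}\sum_{i=1}^{n} \left(d_{k,i} - \hat\mu_k\right)^2} \\
\mathrm{NLL}_k(d) &= \frac{1}{2}\left(\frac{d - \hat\mu_k}{\hat\sigma_k}\right)^2 + \log \hat\sigma_k
\end{aligned}
```

Lower values indicate a more typical body shape. When d equals the prior mean, the cost reduces to log of the prior standard deviation.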

I also considered using ratios of distances, but they appeared less stable, so they were not used in the final model.

library(tidyverse)  # attaches readr, dplyr, tidyr, purrr and ggplot2
library(plotly)
coords <- read_csv("predictions_aligned_40.csv")
str(coords)
spc_tbl_ [281,296 × 6] (S3: spec_tbl_df/tbl_df/tbl/data.frame)
 $ frame   : num [1:281296] 0 0 0 0 0 0 0 0 0 0 ...
 $ instance: num [1:281296] 0 0 0 0 0 0 0 0 1 1 ...
 $ node    : chr [1:281296] "Ear-left" "Ear-right" "Nose" "Mid-center" ...
 $ x       : num [1:281296] 1080 1025 1085 1036 1080 ...
 $ y       : num [1:281296] 612 648 668 580 588 ...
 $ score   : num [1:281296] 0.466 0.44 0.566 0.544 0.338 ...
 - attr(*, "spec")=
  .. cols(
  ..   frame = col_double(),
  ..   instance = col_double(),
  ..   node = col_character(),
  ..   x = col_double(),
  ..   y = col_double(),
  ..   score = col_double()
  .. )
 - attr(*, "problems")=<externalptr> 
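# One row per (frame, instance), with x_<node> / y_<node> columns
coords_wide <- coords |>
  filter(y < 700 | is.na(y)) |>   # keep points with y < 700 (or missing); larger y values are dropped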
  select(frame, instance, node, x, y) |>
  pivot_wider(
    names_from = node,
    values_from = c(x, y),
    names_sep = "_"
  )
coords_mid = coords_wide |>
  mutate(dtn = sqrt((x_Nose - `x_Tail-base`)^2 + 
                      (y_Nose - `y_Tail-base`)^2)) |>
  mutate(dtm = sqrt((`x_Mid-center` - `x_Tail-base`)^2 + 
                    (`y_Mid-center` - `y_Tail-base`)^2)) |>
  mutate(dnm = sqrt((`x_Mid-center` - x_Nose)^2 + 
                    (`y_Mid-center` - y_Nose)^2)) |>
  mutate(dist_ratio = dtm / dnm) 
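The same three distances are computed again below for the manually labelled data. A small base-R helper could compute them for any wide table with x_<node> / y_<node> columns, so both tables are derived identically. The name add_body_distances is hypothetical and not part of the original code.

```r
# Hypothetical helper: add the three pairwise body-part distances (pixels) and
# the tail-mid / nose-mid ratio to a wide table with x_<node> / y_<node> columns.
add_body_distances <- function(df) {
  # Euclidean distance between nodes a and b, row by row
  pt_dist <- function(a, b) {
    sqrt((df[[paste0("x_", a)]] - df[[paste0("x_", b)]])^2 +
         (df[[paste0("y_", a)]] - df[[paste0("y_", b)]])^2)
  }
  df$dtn <- pt_dist("Nose", "Tail-base")        # tail-base to nose
  df$dtm <- pt_dist("Mid-center", "Tail-base")  # tail-base to mid-center
  df$dnm <- pt_dist("Mid-center", "Nose")       # nose to mid-center
  df$dist_ratio <- df$dtm / df$dnm
  df
}

# Toy example: the three points form a 3-4-5 right triangle
toy <- data.frame(`x_Nose` = 3, `y_Nose` = 4,
                  `x_Tail-base` = 0, `y_Tail-base` = 0,
                  `x_Mid-center` = 0, `y_Mid-center` = 4,
                  check.names = FALSE)
toy_d <- add_body_distances(toy)
```

With this helper, coords_mid <- add_body_distances(coords_wide) and manual_mid <- add_body_distances(manual_wide) would replace the two mutate() chains.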

Examine the distributions of the log distance ratio (dtm / dnm) and of the tail-base to nose distance in the tracked data:

ggplot() +
  geom_histogram(data = coords_mid, aes(x = log(dist_ratio), y = after_stat(density))) +
  labs(title = "Real Distance Ratios")
`stat_bin()` using `bins = 30`. Pick better value `binwidth`.
Warning: Removed 10784 rows containing non-finite outside the scale range
(`stat_bin()`).

ggplot() +
  geom_histogram(data = coords_mid, aes(x = dtn, y = after_stat(density))) +
  labs(title = "Real Tail Nose Distances")
`stat_bin()` using `bins = 30`. Pick better value `binwidth`.
Warning: Removed 10558 rows containing non-finite outside the scale range
(`stat_bin()`).

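manual <- read_csv("predictions_manual.csv")
manual <- manual |>
  filter(score == 1) |>                        # keep points with score == 1 (the manual labels)
  filter(!frame %in% c(308, 310, 581, 3511))   # frames excluded by hand
# Assumes 40 remaining frames, each with 8 nodes for instance 0 followed by 8 for instance 1
manual$instance <- rep(rep(0:1, each = 8), 40)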
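The rep() call above hard-codes 40 frames. It also relies on the rows of each frame being ordered as the 8 nodes of instance 0 followed by the 8 nodes of instance 1.

The sketch below keeps that ordering assumption but numbers instances within each frame instead, and stops with an error if a frame does not have exactly 16 rows. The helper label_manual_instances is hypothetical.

```r
# Hypothetical helper: derive the instance label (0/1) from each row's position
# within its frame. Assumes every frame has 2 * nodes_per_instance rows, ordered
# as all nodes of instance 0 followed by all nodes of instance 1.
label_manual_instances <- function(frame, nodes_per_instance = 8) {
  counts <- table(frame)
  bad <- counts != 2 * nodes_per_instance
  if (any(bad)) {
    stop("frames with unexpected row counts: ",
         paste(names(counts)[bad], collapse = ", "))
  }
  pos <- ave(seq_along(frame), frame, FUN = seq_along)  # 1-based row within frame
  (pos - 1) %/% nodes_per_instance
}

# Toy example: two frames of 16 rows each
toy_frames <- rep(c(10, 20), each = 16)
toy_inst <- label_manual_instances(toy_frames)
```

Usage would be manual$instance <- label_manual_instances(manual$frame). The row-count check turns a silent mislabelling into an error if a frame has missing or extra nodes.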

Examine the distributions of the manually labelled body-part distances:

manual_wide <- manual |>
  select(frame, instance, node, x, y) |>
  pivot_wider(
    names_from = node,
    values_from = c(x, y),
    names_sep = "_"
  )
manual_mid = manual_wide |>
  mutate(dtn = sqrt((x_Nose - `x_Tail-base`)^2 + 
                      (y_Nose - `y_Tail-base`)^2)) |>
  mutate(dtm = sqrt((`x_Mid-center` - `x_Tail-base`)^2 + 
                    (`y_Mid-center` - `y_Tail-base`)^2)) |>
  mutate(dnm = sqrt((`x_Mid-center` - x_Nose)^2 + 
                    (`y_Mid-center` - y_Nose)^2)) |>
  mutate(dist_ratio = dtm / dnm)
ggplot() +
  geom_histogram(data = manual_mid, aes(x = dtm, y = after_stat(density))) +
  labs(title = "Tail Mid-center Distances")
`stat_bin()` using `bins = 30`. Pick better value `binwidth`.

ggplot() +
  geom_histogram(data = manual_mid, aes(x = dtn, y = after_stat(density))) +
  labs(title = "Tail Nose Distances")
`stat_bin()` using `bins = 30`. Pick better value `binwidth`.

ggplot() +
  geom_histogram(data = manual_mid, aes(x = dnm, y = after_stat(density))) +
  labs(title = "Nose Mid-center Distances")
`stat_bin()` using `bins = 30`. Pick better value `binwidth`.

pri_dtm_mean = mean(manual_mid$dtm)
pri_dtm_sd = sd(manual_mid$dtm)
pri_dtn_mean = mean(manual_mid$dtn)
pri_dtn_sd = sd(manual_mid$dtn)
pri_dnm_mean = mean(manual_mid$dnm)
pri_dnm_sd = sd(manual_mid$dnm)

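Tracking proceeds through the frames in order, updating the tracked positions after each frame before moving on to the next. In each frame, the nose, mid-center and tail-base each have up to two candidate coordinates (one per detected instance, possibly NA). This gives 2^3 = 8 ways to assign candidates to instance 0; instance 1 receives the remaining candidate for each part. Missing candidates are filled with the previously accepted position. For each assignment and instance, the score has two parts:

- the average negative log-likelihood of the three body-part distances under the Gaussian prior (ignoring NA distances), multiplied by a weight decay_rate^(gap - 1) that shrinks with the number of frames since the parts were last accepted;
- a linear penalty for any movement beyond an allowance that grows with that frame gap.

Lower scores are better, and an assignment is acceptable for an instance if its score is below the threshold. Selection works as follows:

- If some assignment is acceptable for both instances, the one with the smallest total mid-center displacement is taken.
- Otherwise, only the instance whose acceptable assignment moves its mid-center least is updated.
- If nothing is acceptable, the coordinates from the previous frame are kept.

A final check swaps the two identities if doing so more than halves the total displacement of all parts. A part is marked as reliable in a frame when it was accepted and its position changed. Because the prior weight decays and the motion allowance grows with the frame gap, an instance whose track has gone cold for many frames can jump to a new position once the animal reappears.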
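A simplified sketch of the per-frame scoring described above follows. It assumes the priors list format used later (NM, TM, TN, each c(mean = , sd = )). The helper names are hypothetical. Unlike process_forward_combos(), the sketch uses a single frame gap per instance rather than one per body part.

```r
# Gaussian negative log-likelihood with the constant 0.5 * log(2 * pi) dropped
gaussian_nll <- function(d, mean, sd) 0.5 * ((d - mean) / sd)^2 + log(sd)

# The 8 assignments: instance 0 takes candidate 1 or 2 for each part,
# instance 1 takes the other one (3 - choice).
enumerate_assignments <- function() {
  inst0 <- expand.grid(nose = 1:2, mid = 1:2, tail = 1:2)
  list(inst0 = inst0, inst1 = 3 - inst0)
}

# Score for one instance under one assignment (lower is better).
#   d     : named distances c(NM = , TM = , TN = ); NA if a part is missing
#   moved : per-part displacement (pixels) since the last accepted position
#   gap   : frames since the instance was last accepted (>= 1)
instance_score <- function(d, moved, gap, priors,
                           decay_rate = 0.9, motion_alpha = 55,
                           motion_lambda = 1, big_penalty = 1e5) {
  nlls <- mapply(function(x, p) gaussian_nll(x, p[["mean"]], p[["sd"]]),
                 d, priors[names(d)])
  if (all(is.na(nlls))) return(big_penalty)
  # shape term: average NLL, down-weighted as the track goes cold
  shape <- mean(nlls, na.rm = TRUE) * decay_rate^(gap - 1)
  # motion term: linear penalty for movement beyond gap * motion_alpha pixels
  excess <- pmax(0, moved - gap * motion_alpha)
  shape + motion_lambda * sum(excess, na.rm = TRUE)
}

# Toy priors, a well-shaped slowly moving instance, and one whose nose jumps
toy_priors <- list(NM = c(mean = 100, sd = 10),
                   TM = c(mean = 70, sd = 10),
                   TN = c(mean = 150, sd = 20))
s_ok   <- instance_score(c(NM = 100, TM = 70, TN = 150), moved = c(5, 5, 5),
                         gap = 1, priors = toy_priors)
s_jump <- instance_score(c(NM = 100, TM = 70, TN = 150), moved = c(100, 5, 5),
                         gap = 1, priors = toy_priors)
s_cold <- instance_score(c(NM = 100, TM = 70, TN = 150), moved = c(0, 0, 0),
                         gap = 11, priors = toy_priors)
```

Take decay_rate = 0.9 and a part last accepted 11 frames ago. Its shape term is scaled by 0.9^10 ≈ 0.35. With motion_alpha = 55, it may also move 11 × 55 = 605 pixels without penalty, which is what lets a cold track jump.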

# Frames with more than two detected instances are reduced to two: repeatedly drop
# the second member of the pair whose mid-centers are closest (most likely a
# duplicate detection of the same mouse), then relabel the remaining instances 0 and 1.
preprocess_coords <- function(coords_sort) {

  coords_sort %>%
    group_by(frame) %>%
    group_modify(~ {
      df <- .x
      if (nrow(df) <= 2) return(df)  # nothing to do

      while (nrow(df) > 2) {
        n <- nrow(df)

        # Pairwise distances between mid-centers (n x n symmetric matrix)
        mids  <- df %>% select(`x_Mid-center`, `y_Mid-center`)
        dists <- as.matrix(dist(mids))

        # Unique pairs i < j (slice_min() sorts NA distances last)
        pairs <- expand.grid(i = seq_len(n), j = seq_len(n))
        pairs <- pairs[pairs$i < pairs$j, ]
        pairs$dist <- dists[cbind(pairs$i, pairs$j)]
        closest_pair <- pairs %>% slice_min(dist, with_ties = FALSE)

        # Remove the second of the closest pair
        df <- df[-closest_pair$j, ]
      }

      # Relabel the remaining instances as 0 and 1
      df %>% mutate(instance = seq_len(nrow(df)) - 1)
    }) %>%
    ungroup()
}


coords_sort = preprocess_coords(coords_mid |> arrange(frame))
str(coords_sort)
tibble [35,080 × 22] (S3: tbl_df/tbl/data.frame)
 $ frame       : num [1:35080] 0 0 1 1 2 2 3 3 4 4 ...
 $ instance    : num [1:35080] 0 1 0 1 0 1 0 1 0 1 ...
 $ x_Ear-left  : num [1:35080] 1080 404 1084 404 1085 ...
 $ x_Ear-right : num [1:35080] 1025 400 1033 400 1052 ...
 $ x_Nose      : num [1:35080] 1085 468 1104 468 1117 ...
 $ x_Mid-center: num [1:35080] 1036 349 1037 349 1041 ...
 $ x_Mid-left  : num [1:35080] 1080 372 1080 372 1084 ...
 $ x_Mid-right : num [1:35080] 992 NA 993 NA 997 ...
 $ x_Tail-base : num [1:35080] 1064 272 1064 272 1061 ...
 $ x_Tail-tip  : num [1:35080] NA NA NA NA NA NA NA NA NA NA ...
 $ y_Ear-left  : num [1:35080] 612 609 605 612 601 ...
 $ y_Ear-right : num [1:35080] 648 676 649 676 656 ...
 $ y_Nose      : num [1:35080] 668 640 657 640 644 ...
 $ y_Mid-center: num [1:35080] 580 633 580 633 584 ...
 $ y_Mid-left  : num [1:35080] 588 601 584 600 580 ...
 $ y_Mid-right : num [1:35080] 577 NA 580 NA 588 ...
 $ y_Tail-base : num [1:35080] 520 605 517 605 516 ...
 $ y_Tail-tip  : num [1:35080] NA NA NA NA NA NA NA NA NA NA ...
 $ dtn         : num [1:35080] 150 199 145 199 140 ...
 $ dtm         : num [1:35080] 66.5 81.5 69.3 81.4 70.7 ...
 $ dnm         : num [1:35080] 100.4 119.5 101.8 119.6 97.4 ...
 $ dist_ratio  : num [1:35080] 0.662 0.682 0.681 0.681 0.726 ...

Helper Functions:

process_forward_combos <- function(coords_sort, r, priors, threshold,
                                   motion_lambda = 1, motion_alpha = 1,
                                   big_penalty = 1e6, too_close_thresh = 10,
                                   decay_rate = 0.8,
                                   verbose = FALSE) {
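  # --- helpers ---
  # Gaussian negative log-likelihood (constant 0.5 * log(2 * pi) dropped);
  # a missing distance contributes 0
  nll <- function(d, mean, sd) {
    if (length(d) == 0 || is.na(d)) return(0)
    0.5 * ((d - mean) / sd)^2 + log(sd)
  }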
  dist_xy <- function(a, b) {
    if (length(a) != 2 || length(b) != 2) return(NA_real_)
    if (any(is.na(a)) || any(is.na(b))) return(NA_real_)
    sqrt((a[1] - b[1])^2 + (a[2] - b[2])^2)
  }

  # --- checks ---
  required_cols <- c("frame", "instance", "x_Nose", "y_Nose",
                     "x_Mid-center", "y_Mid-center",
                     "x_Tail-base", "y_Tail-base")
  stopifnot(all(required_cols %in% colnames(coords_sort)))
  frames_all <- sort(unique(coords_sort$frame))
  if (! r %in% frames_all)
    stop("start frame r not present in coords_sort")

  combos <- expand.grid(nose_choice = 1:2,
                        mid_choice = 1:2,
                        tail_choice = 1:2,
                        stringsAsFactors = FALSE)
  frames_seq <- frames_all[frames_all >= r]

  # --- initialize from frame r ---
  fr_r_rows <- coords_sort %>% filter(frame == r) %>% arrange(instance)
  if (nrow(fr_r_rows) == 0) stop("No rows at start frame r")
  if (nrow(fr_r_rows) == 1) {
    pad_row <- fr_r_rows[1, , drop = FALSE]; pad_row[,] <- NA
    fr_r_rows <- bind_rows(fr_r_rows, pad_row)
  } else {
    fr_r_rows <- fr_r_rows[1:2, , drop = FALSE]
  }

  # previously accepted positions, and the frame at which each part was last accepted
  prev_frame_nose0 <- prev_frame_mid0 <- prev_frame_tail0 <- r
  prev_frame_nose1 <- prev_frame_mid1 <- prev_frame_tail1 <- r

  prev_nose0 <- c(fr_r_rows[1, "x_Nose"], fr_r_rows[1, "y_Nose"]) %>% as.numeric()
  prev_mid0  <- c(fr_r_rows[1, "x_Mid-center"], fr_r_rows[1, "y_Mid-center"]) %>% as.numeric()
  prev_tail0 <- c(fr_r_rows[1, "x_Tail-base"], fr_r_rows[1, "y_Tail-base"]) %>% as.numeric()
  prev_nose1 <- c(fr_r_rows[2, "x_Nose"], fr_r_rows[2, "y_Nose"]) %>% as.numeric()
  prev_mid1  <- c(fr_r_rows[2, "x_Mid-center"], fr_r_rows[2, "y_Mid-center"]) %>% as.numeric()
  prev_tail1 <- c(fr_r_rows[2, "x_Tail-base"], fr_r_rows[2, "y_Tail-base"]) %>% as.numeric()

  out <- vector("list", length(frames_seq))
  row_index <- 0

  # --- loop over frames ---
  for (f in frames_seq) {
    row_index <- row_index + 1
    fr_rows <- coords_sort %>% filter(frame == f) %>% arrange(instance)

    # keep copies of all parts before any updates for the sanity swap check later
    prev_mid0_before <- prev_mid0
    prev_mid1_before <- prev_mid1
    prev_nose0_before <- prev_nose0
    prev_nose1_before <- prev_nose1
    prev_tail0_before <- prev_tail0
    prev_tail1_before <- prev_tail1

    # no detections: carry forward
    if (nrow(fr_rows) == 0) {
      out[[row_index]] <- tibble(
        frame = f, best_combo = NA_integer_,
        score0 = NA_real_, score1 = NA_real_,
        inst0_mid_x = prev_mid0[1], inst0_mid_y = prev_mid0[2],
        inst1_mid_x = prev_mid1[1], inst1_mid_y = prev_mid1[2],
        inst0_nose_x = prev_nose0[1], inst0_nose_y = prev_nose0[2],
        inst1_nose_x = prev_nose1[1], inst1_nose_y = prev_nose1[2],
        inst0_tail_x = prev_tail0[1], inst0_tail_y = prev_tail0[2],
        inst1_tail_x = prev_tail1[1], inst1_tail_y = prev_tail1[2],
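        reliability_mid0 = FALSE, reliability_mid1 = FALSE,
        reliability_nose0 = FALSE, reliability_nose1 = FALSE,
        reliability_tail0 = FALSE, reliability_tail1 = FALSE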
      )
      next
    }

    if (nrow(fr_rows) == 1) {
      pad_row <- fr_rows[1, , drop = FALSE]; pad_row[,] <- NA
      fr_rows <- bind_rows(fr_rows, pad_row)
    } else {
      fr_rows <- fr_rows[1:2, , drop = FALSE]
    }

    # helper: TRUE if both points exist and are closer than `thresh` pixels
    too_close <- function(p1, p2, thresh) {
      if (any(is.na(p1)) || any(is.na(p2))) return(FALSE)
      sqrt(sum((p1 - p2)^2)) < thresh
    }

    # helper: if the two candidates for a part are too close, treat them as one
    # detection: keep their mean as candidate 1 and mark candidate 2 as missing
    fix_too_close <- function(cands, too_close_thresh) {
      if (length(cands) >= 2) {
        if (too_close(cands[[1]], cands[[2]], too_close_thresh)) {
          mean_pt <- colMeans(rbind(cands[[1]], cands[[2]]), na.rm = TRUE)
          cands[[1]] <- mean_pt
          cands[[2]] <- c(NA, NA)
        }
      }
      cands
    }

    # candidate lists (with cleaning)
    nose_cands <- list(
      c(fr_rows[["x_Nose"]][1], fr_rows[["y_Nose"]][1]) %>% as.numeric(),
      c(fr_rows[["x_Nose"]][2], fr_rows[["y_Nose"]][2]) %>% as.numeric()
    ) %>% fix_too_close(too_close_thresh)

    mid_cands <- list(
      c(fr_rows[["x_Mid-center"]][1], fr_rows[["y_Mid-center"]][1]) %>% as.numeric(),
      c(fr_rows[["x_Mid-center"]][2], fr_rows[["y_Mid-center"]][2]) %>% as.numeric()
    ) %>% fix_too_close(too_close_thresh)

    tail_cands <- list(
      c(fr_rows[["x_Tail-base"]][1], fr_rows[["y_Tail-base"]][1]) %>% as.numeric(),
      c(fr_rows[["x_Tail-base"]][2], fr_rows[["y_Tail-base"]][2]) %>% as.numeric()
    ) %>% fix_too_close(too_close_thresh)

    # pre-allocate one row of scores and coordinates per combination
    combo_scores <- tibble(
      combo = seq_len(nrow(combos)),
      score0 = NA_real_, score1 = NA_real_,
      mid0x = NA_real_, mid0y = NA_real_,
      mid1x = NA_real_, mid1y = NA_real_,
      nose0x = NA_real_, nose0y = NA_real_,
      nose1x = NA_real_, nose1y = NA_real_,
      tail0x = NA_real_, tail0y = NA_real_,
      tail1x = NA_real_, tail1y = NA_real_
    )

    # helper: fall back to the previous accepted position if a candidate is missing
    fill_with_prev <- function(curr, prev) {
      if (any(is.na(curr))) return(prev)
      curr
    }

    # --- compute scores for every combination ---
    for (ci in seq_len(nrow(combos))) {
      nc <- combos$nose_choice[ci]
      mc <- combos$mid_choice[ci]
      tc <- combos$tail_choice[ci]

      # raw candidates: instance 0 takes the chosen candidate, instance 1 the other
      nose0 <- nose_cands[[nc]]
      mid0  <- mid_cands[[mc]]
      tail0 <- tail_cands[[tc]]
      nose1 <- nose_cands[[3 - nc]]
      mid1  <- mid_cands[[3 - mc]]
      tail1 <- tail_cands[[3 - tc]]

      # replace NA with previous frame coords
      nose0 <- fill_with_prev(nose0, prev_nose0)
      mid0  <- fill_with_prev(mid0,  prev_mid0)
      tail0 <- fill_with_prev(tail0, prev_tail0)

      nose1 <- fill_with_prev(nose1, prev_nose1)
      mid1  <- fill_with_prev(mid1,  prev_mid1)
      tail1 <- fill_with_prev(tail1, prev_tail1)

      # body-part distances for each instance
      d_nm0 <- dist_xy(nose0, mid0); d_tm0 <- dist_xy(tail0, mid0); d_tn0 <- dist_xy(tail0, nose0)
      d_nm1 <- dist_xy(nose1, mid1); d_tm1 <- dist_xy(tail1, mid1); d_tn1 <- dist_xy(tail1, nose1)

      # frames since each part was last accepted (at least 1)
      frame_gap_nose0 <- max(1, f - prev_frame_nose0)
      frame_gap_mid0  <- max(1, f - prev_frame_mid0)
      frame_gap_tail0 <- max(1, f - prev_frame_tail0)

      frame_gap_nose1 <- max(1, f - prev_frame_nose1)
      frame_gap_mid1  <- max(1, f - prev_frame_mid1)
      frame_gap_tail1 <- max(1, f - prev_frame_tail1)

      # --- per-distance NLLs, down-weighted by decay_rate^(gap - 1) ---
      nll_nm0 <- if (!is.na(d_nm0)) nll(d_nm0, priors$NM["mean"], priors$NM["sd"]) *
                   decay_rate^(max(frame_gap_nose0, frame_gap_mid0) - 1) else NA_real_
      nll_tm0 <- if (!is.na(d_tm0)) nll(d_tm0, priors$TM["mean"], priors$TM["sd"]) *
                   decay_rate^(max(frame_gap_tail0, frame_gap_mid0) - 1) else NA_real_
      nll_tn0 <- if (!is.na(d_tn0)) nll(d_tn0, priors$TN["mean"], priors$TN["sd"]) *
                   decay_rate^(max(frame_gap_tail0, frame_gap_nose0) - 1) else NA_real_

      nll_nm1 <- if (!is.na(d_nm1)) nll(d_nm1, priors$NM["mean"], priors$NM["sd"]) *
                   decay_rate^(max(frame_gap_nose1, frame_gap_mid1) - 1) else NA_real_
      nll_tm1 <- if (!is.na(d_tm1)) nll(d_tm1, priors$TM["mean"], priors$TM["sd"]) *
                   decay_rate^(max(frame_gap_tail1, frame_gap_mid1) - 1) else NA_real_
      nll_tn1 <- if (!is.na(d_tn1)) nll(d_tn1, priors$TN["mean"], priors$TN["sd"]) *
                   decay_rate^(max(frame_gap_tail1, frame_gap_nose1) - 1) else NA_real_

      # --- average NLLs, with big_penalty if no distance is available ---
      avg_nll0 <- if (all(is.na(c(nll_nm0, nll_tm0, nll_tn0)))) big_penalty else mean(c(nll_nm0, nll_tm0, nll_tn0), na.rm = TRUE)
      avg_nll1 <- if (all(is.na(c(nll_nm1, nll_tm1, nll_tn1)))) big_penalty else mean(c(nll_nm1, nll_tm1, nll_tn1), na.rm = TRUE)

      # --- motion per part: distance moved since the last accepted position ---
      dists0 <- c(
        nose = dist_xy(prev_nose0, nose0),
        mid  = dist_xy(prev_mid0,  mid0),
        tail = dist_xy(prev_tail0, tail0)
      )

      dists1 <- c(
        nose = dist_xy(prev_nose1, nose1),
        mid  = dist_xy(prev_mid1,  mid1),
        tail = dist_xy(prev_tail1, tail1)
      )

      # motion = max distance among body parts (kept for reference; not used in the score)
      motion0 <- if (all(is.na(dists0))) NA_real_ else max(dists0, na.rm = TRUE)
      motion1 <- if (all(is.na(dists1))) NA_real_ else max(dists1, na.rm = TRUE)

      # frame gaps per body part
      gaps0 <- c(nose = frame_gap_nose0, mid = frame_gap_mid0, tail = frame_gap_tail0)
      gaps1 <- c(nose = frame_gap_nose1, mid = frame_gap_mid1, tail = frame_gap_tail1)

      # motion penalty per part: linear in the movement beyond gap * motion_alpha pixels
      penalties0 <- mapply(function(mot, gap) {
        if (!is.na(mot) && mot > (gap * motion_alpha)) {
          motion_lambda * (mot - gap * motion_alpha)
        } else {
          0
        }
      }, dists0, gaps0)

      penalties1 <- mapply(function(mot, gap) {
        if (!is.na(mot) && mot > (gap * motion_alpha)) {
          motion_lambda * (mot - gap * motion_alpha)
        } else {
          0
        }
      }, dists1, gaps1)

      # sum penalties per instance
      motion_pen0 <- sum(penalties0)
      motion_pen1 <- sum(penalties1)

      # final score: decayed average NLL + motion penalty (lower is better)
      combo_scores$score0[ci] <- avg_nll0 + motion_pen0
      combo_scores$score1[ci] <- avg_nll1 + motion_pen1

      # store the (NA-filled) coordinates of this combination
      combo_scores$mid0x[ci] <- mid0[1]; combo_scores$mid0y[ci] <- mid0[2]
      combo_scores$mid1x[ci] <- mid1[1]; combo_scores$mid1y[ci] <- mid1[2]
      combo_scores$nose0x[ci] <- nose0[1]; combo_scores$nose0y[ci] <- nose0[2]
      combo_scores$nose1x[ci] <- nose1[1]; combo_scores$nose1y[ci] <- nose1[2]
      combo_scores$tail0x[ci] <- tail0[1]; combo_scores$tail0y[ci] <- tail0[2]
      combo_scores$tail1x[ci] <- tail1[1]; combo_scores$tail1y[ci] <- tail1[2]
    }


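    # distance between the two instances' mid-centers in each combination
    combo_scores <- combo_scores %>%
      mutate(mid_dist = sqrt((mid0x - mid1x)^2 + (mid0y - mid1y)^2))

    # debugging: print every combination's scores at one inspected frame
    if (verbose && f == 5016) {
      print(combo_scores)
    }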

    # combinations acceptable for both instances: both scores below threshold and
    # the two mid-centers at least too_close_thresh apart
    valid_both <- combo_scores %>%
      filter(!is.na(score0), !is.na(score1), score0 < threshold, score1 < threshold,
             mid_dist >= too_close_thresh)

    # distance from each candidate mid-center to the other instance's previous mid-center
    dist0_prev1 <- sqrt((combo_scores$mid0x - prev_mid1[1])^2 +
                        (combo_scores$mid0y - prev_mid1[2])^2)
    dist1_prev0 <- sqrt((combo_scores$mid1x - prev_mid0[1])^2 +
                        (combo_scores$mid1y - prev_mid0[2])^2)

    # combinations acceptable for one instance (and not landing on the other instance)
    valid0 <- combo_scores %>%
      filter(!is.na(score0), score0 < threshold, dist0_prev1 >= too_close_thresh)
    valid1 <- combo_scores %>%
      filter(!is.na(score1), score1 < threshold, dist1_prev0 >= too_close_thresh)

    # --- decision ---
    chosen <- NULL; accept0 <- FALSE; accept1 <- FALSE
    tol <- 1e-6  # minimum displacement counted as positive movement

    if (nrow(valid_both) > 0) {
      # both instances acceptable: take the combination with the smallest total
      # mid-center displacement from the previous positions
      comb_dist <- sqrt((valid_both$mid0x - prev_mid0[1])^2 +
                        (valid_both$mid0y - prev_mid0[2])^2) +
                   sqrt((valid_both$mid1x - prev_mid1[1])^2 +
                        (valid_both$mid1y - prev_mid1[2])^2)
      chosen <- valid_both[which.min(comb_dist), ]
      accept0 <- TRUE; accept1 <- TRUE

    } else {
      # otherwise update only one instance: the one whose acceptable combination
      # moves its mid-center least (combinations with no movement are ignored)
      d0 <- sqrt((valid0$mid0x - prev_mid0[1])^2 + (valid0$mid0y - prev_mid0[2])^2)
      d1 <- sqrt((valid1$mid1x - prev_mid1[1])^2 + (valid1$mid1y - prev_mid1[2])^2)

      # keep combinations with positive movement
      valid0_pos <- valid0[d0 > tol, , drop = FALSE]
      valid1_pos <- valid1[d1 > tol, , drop = FALSE]
      d0_pos <- d0[d0 > tol]
      d1_pos <- d1[d1 > tol]

      if (nrow(valid0_pos) > 0 && nrow(valid1_pos) > 0) {
        best0 <- valid0_pos[which.min(d0_pos), ]
        best1 <- valid1_pos[which.min(d1_pos), ]
        if (min(d0_pos) <= min(d1_pos)) {
          chosen <- best0; accept0 <- TRUE
        } else {
          chosen <- best1; accept1 <- TRUE
        }

      } else if (nrow(valid0_pos) > 0) {
        chosen <- valid0_pos[which.min(d0_pos), ]; accept0 <- TRUE

      } else if (nrow(valid1_pos) > 0) {
        chosen <- valid1_pos[which.min(d1_pos), ]; accept1 <- TRUE

      } else {
        # fallback: nothing acceptable with positive movement, keep previous coordinates
        chosen <- tibble(
          combo = NA_integer_,
          mid0x = prev_mid0[1], mid0y = prev_mid0[2],
          mid1x = prev_mid1[1], mid1y = prev_mid1[2],
          nose0x = prev_nose0[1], nose0y = prev_nose0[2],
          nose1x = prev_nose1[1], nose1y = prev_nose1[2],
          tail0x = prev_tail0[1], tail0y = prev_tail0[2],
          tail1x = prev_tail1[1], tail1y = prev_tail1[2],
          score0 = Inf, score1 = Inf
        )
      }
    }


    # --- initialize per-part reliability ---
    reliability_nose0 <- reliability_mid0 <- reliability_tail0 <- FALSE
    reliability_nose1 <- reliability_mid1 <- reliability_tail1 <- FALSE

    # --- update prev and reliabilities ---
    # a part is reliable in this frame if it was accepted and its position changed
    if (accept0) {
      # Mid
      if (!any(is.na(chosen[c("mid0x","mid0y")]))) {
        moved <- any(c(chosen$mid0x, chosen$mid0y) != prev_mid0)
        prev_mid0 <- c(chosen$mid0x, chosen$mid0y)
        reliability_mid0 <- moved
        prev_frame_mid0 <- f
      }
      # Nose
      if (!any(is.na(chosen[c("nose0x","nose0y")]))) {
        moved <- any(c(chosen$nose0x, chosen$nose0y) != prev_nose0)
        prev_nose0 <- c(chosen$nose0x, chosen$nose0y)
        reliability_nose0 <- moved
        prev_frame_nose0 <- f
      }
      # Tail
      if (!any(is.na(chosen[c("tail0x","tail0y")]))) {
        moved <- any(c(chosen$tail0x, chosen$tail0y) != prev_tail0)
        prev_tail0 <- c(chosen$tail0x, chosen$tail0y)
        reliability_tail0 <- moved
        prev_frame_tail0 <- f
      }
    }

    if (accept1) {
      # Mid
      if (!any(is.na(chosen[c("mid1x","mid1y")]))) {
        moved <- any(c(chosen$mid1x, chosen$mid1y) != prev_mid1)
        prev_mid1 <- c(chosen$mid1x, chosen$mid1y)
        reliability_mid1 <- moved
        prev_frame_mid1 <- f
      }
      # Nose
      if (!any(is.na(chosen[c("nose1x","nose1y")]))) {
        moved <- any(c(chosen$nose1x, chosen$nose1y) != prev_nose1)
        prev_nose1 <- c(chosen$nose1x, chosen$nose1y)
        reliability_nose1 <- moved
        prev_frame_nose1 <- f
      }
      # Tail
      if (!any(is.na(chosen[c("tail1x","tail1y")]))) {
        moved <- any(c(chosen$tail1x, chosen$tail1y) != prev_tail1)
        prev_tail1 <- c(chosen$tail1x, chosen$tail1y)
        reliability_tail1 <- moved
        prev_frame_tail1 <- f
      }
    }

    # --- SANITY SWAP CHECK (ENTIRE INSTANCE SWAP, ALL PARTS) ---
    # If exchanging the two identities more than halves the total displacement of
    # all parts since the previous frame, assume the labels were swapped and swap back.
    if (
      !any(is.na(prev_mid0_before)) && !any(is.na(prev_mid1_before)) &&
      !any(is.na(prev_mid0)) && !any(is.na(prev_mid1)) &&
      !any(is.na(prev_nose0_before)) && !any(is.na(prev_nose1_before)) &&
      !any(is.na(prev_tail0_before)) && !any(is.na(prev_tail1_before))
    ) {
      # original assignment movement
      orig_sum <- dist_xy(prev_mid0_before,  prev_mid0)  +
                  dist_xy(prev_mid1_before,  prev_mid1)  +
                  dist_xy(prev_nose0_before, prev_nose0) +
                  dist_xy(prev_nose1_before, prev_nose1) +
                  dist_xy(prev_tail0_before, prev_tail0) +
                  dist_xy(prev_tail1_before, prev_tail1)

      # swapped assignment movement
      swap_sum <- dist_xy(prev_mid0_before,  prev_mid1)  +
                  dist_xy(prev_mid1_before,  prev_mid0)  +
                  dist_xy(prev_nose0_before, prev_nose1) +
                  dist_xy(prev_nose1_before, prev_nose0) +
                  dist_xy(prev_tail0_before, prev_tail1) +
                  dist_xy(prev_tail1_before, prev_tail0)

      if (!is.na(orig_sum) && !is.na(swap_sum) && swap_sum < 0.5 * orig_sum) {
        # perform full swap (positions, last-accepted frames and reliabilities)
        tmp <- list(
          mid        = prev_mid0,
          nose       = prev_nose0,
          tail       = prev_tail0,
          frame_mid  = prev_frame_mid0,
          frame_nose = prev_frame_nose0,
          frame_tail = prev_frame_tail0,
          rel_mid    = reliability_mid0,
          rel_nose   = reliability_nose0,
          rel_tail   = reliability_tail0
        )

        prev_mid0         <- prev_mid1
        prev_nose0        <- prev_nose1
        prev_tail0        <- prev_tail1
        prev_frame_mid0   <- prev_frame_mid1
        prev_frame_nose0  <- prev_frame_nose1
        prev_frame_tail0  <- prev_frame_tail1
        reliability_mid0  <- reliability_mid1
        reliability_nose0 <- reliability_nose1
        reliability_tail0 <- reliability_tail1

        prev_mid1         <- tmp$mid
        prev_nose1        <- tmp$nose
        prev_tail1        <- tmp$tail
        prev_frame_mid1   <- tmp$frame_mid
        prev_frame_nose1  <- tmp$frame_nose
        prev_frame_tail1  <- tmp$frame_tail
        reliability_mid1  <- tmp$rel_mid
        reliability_nose1 <- tmp$rel_nose
        reliability_tail1 <- tmp$rel_tail
      }
    }



    # --- save result including per-part reliabilities ---
    out[[row_index]] <- tibble(
      frame = f,
      best_combo = chosen$combo,
      score0 = chosen$score0, score1 = chosen$score1,
      inst0_mid_x = prev_mid0[1], inst0_mid_y = prev_mid0[2],
      inst1_mid_x = prev_mid1[1], inst1_mid_y = prev_mid1[2],
      inst0_nose_x = prev_nose0[1], inst0_nose_y = prev_nose0[2],
      inst1_nose_x = prev_nose1[1], inst1_nose_y = prev_nose1[2],
      inst0_tail_x = prev_tail0[1], inst0_tail_y = prev_tail0[2],
      inst1_tail_x = prev_tail1[1], inst1_tail_y = prev_tail1[2],
      reliability_mid0 = reliability_mid0, reliability_mid1 = reliability_mid1,
      reliability_nose0 = reliability_nose0, reliability_nose1 = reliability_nose1,
      reliability_tail0 = reliability_tail0, reliability_tail1 = reliability_tail1
    )

    if (verbose && f %% 100 == 0) {
      message("Frame ", f, ": score0=", signif(chosen$score0, 4),
              ", score1=", signif(chosen$score1, 4),
              ", accept0=", accept0, ", accept1=", accept1)
    }
  }
  bind_rows(out)
}

Run the tracker from frame 1 with the fitted priors (progress is reported every 100 frames):

priors <- list(
  NM = c(mean = pri_dnm_mean, sd = pri_dnm_sd),
  TM = c(mean = pri_dtm_mean, sd = pri_dtm_sd),
  TN = c(mean = pri_dtn_mean, sd = pri_dtn_sd)
)

res <- process_forward_combos(
  coords_sort = coords_sort,
  r = 1,
  priors = priors,
  threshold = 20,
  too_close_thresh = 15,
  motion_lambda = 1,
  motion_alpha = 55,
  big_penalty = 1e5,
  decay_rate = 0.9,
  verbose = TRUE
)
Frame 100: score0=3.168, score1=2.901, accept0=TRUE, accept1=TRUE
Frame 200: score0=4.171, score1=2.969, accept0=TRUE, accept1=TRUE
Frame 300: score0=3.238, score1=2.97, accept0=TRUE, accept1=TRUE
Frame 400: score0=2.943, score1=8.053, accept0=TRUE, accept1=TRUE
Frame 500: score0=6.791, score1=4.292, accept0=TRUE, accept1=TRUE
Frame 600: score0=5.187, score1=5.42, accept0=TRUE, accept1=TRUE
Frame 700: score0=10.1, score1=3.191, accept0=TRUE, accept1=TRUE
Frame 800: score0=3.794, score1=3.291, accept0=TRUE, accept1=TRUE
Frame 900: score0=3.603, score1=3.544, accept0=TRUE, accept1=TRUE
Frame 1000: score0=5.575, score1=3.097, accept0=TRUE, accept1=TRUE
Frame 1100: score0=3.484, score1=3.814, accept0=TRUE, accept1=TRUE
Frame 1200: score0=3.134, score1=5.169, accept0=TRUE, accept1=TRUE
Frame 1300: score0=10.52, score1=4.969, accept0=TRUE, accept1=TRUE
Frame 1400: score0=Inf, score1=Inf, accept0=FALSE, accept1=FALSE
Frame 1500: score0=3.275, score1=8.136, accept0=TRUE, accept1=TRUE
Frame 1600: score0=2.978, score1=3.233, accept0=TRUE, accept1=TRUE
Frame 1700: score0=2.977, score1=3.902, accept0=TRUE, accept1=TRUE
Frame 1800: score0=26.81, score1=9.825, accept0=FALSE, accept1=TRUE
Frame 1900: score0=4.93, score1=26.65, accept0=TRUE, accept1=FALSE
Frame 2000: score0=Inf, score1=Inf, accept0=FALSE, accept1=FALSE
Frame 2100: score0=6.463, score1=3.644, accept0=TRUE, accept1=TRUE
Frame 2200: score0=3.237, score1=2.933, accept0=TRUE, accept1=TRUE
Frame 2300: score0=2.931, score1=3.216, accept0=TRUE, accept1=TRUE
Frame 2400: score0=3.151, score1=3.433, accept0=TRUE, accept1=TRUE
Frame 2500: score0=4.154, score1=3.022, accept0=TRUE, accept1=TRUE
Frame 2600: score0=3.113, score1=2.97, accept0=TRUE, accept1=TRUE
Frame 2700: score0=3.021, score1=3.269, accept0=TRUE, accept1=TRUE
Frame 2800: score0=19.54, score1=4.129, accept0=TRUE, accept1=TRUE
Frame 2900: score0=3.688, score1=8.453, accept0=TRUE, accept1=TRUE
Frame 3000: score0=3.032, score1=3.075, accept0=TRUE, accept1=TRUE
Frame 3100: score0=10.3, score1=4.515, accept0=TRUE, accept1=TRUE
Frame 3200: score0=4.866, score1=8.428, accept0=TRUE, accept1=TRUE
Frame 3300: score0=17.22, score1=4.996, accept0=TRUE, accept1=TRUE
Frame 3400: score0=7.171, score1=61.78, accept0=TRUE, accept1=FALSE
Frame 3500: score0=4.718, score1=2.908, accept0=TRUE, accept1=TRUE
Frame 3600: score0=3.299, score1=4.619, accept0=TRUE, accept1=TRUE
Frame 3700: score0=2.918, score1=3.569, accept0=TRUE, accept1=TRUE
Frame 3800: score0=3.484, score1=3.456, accept0=TRUE, accept1=TRUE
Frame 3900: score0=24.52, score1=4.9, accept0=FALSE, accept1=TRUE
Frame 4000: score0=Inf, score1=Inf, accept0=FALSE, accept1=FALSE
Frame 4100: score0=2.977, score1=3.188, accept0=TRUE, accept1=TRUE
Frame 4200: score0=29.19, score1=7.83, accept0=FALSE, accept1=TRUE
Frame 4300: score0=8.205, score1=4.224, accept0=TRUE, accept1=TRUE
Frame 4400: score0=6.43, score1=3.527, accept0=TRUE, accept1=TRUE
Frame 4500: score0=3.241, score1=5.465, accept0=TRUE, accept1=TRUE
Frame 4600: score0=14.2, score1=2.908, accept0=TRUE, accept1=TRUE
Frame 4700: score0=3.084, score1=2.917, accept0=TRUE, accept1=TRUE
Frame 4800: score0=13.4, score1=7.64, accept0=TRUE, accept1=TRUE
Frame 4900: score0=15.97, score1=8.491, accept0=TRUE, accept1=TRUE
Frame 5000: score0=8.186, score1=204.2, accept0=TRUE, accept1=FALSE
# A tibble: 8 × 16
  combo score0 score1 mid0x mid0y mid1x mid1y nose0x nose0y nose1x nose1y tail0x
  <int>  <dbl>  <dbl> <dbl> <dbl> <dbl> <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>
1     1  186.   131.  1281.  429. 1301.  585.  1292.   456.  1208.   656.  1229.
2     2  359.   131.  1281.  429. 1301.  585.  1208.   656.  1224.   649.  1229.
3     3   99.4   21.0 1301.  585. 1281.  429.  1292.   456.  1208.   656.  1229.
4     4  264.    17.9 1301.  585. 1281.  429.  1208.   656.  1224.   649.  1229.
5     5  113.   130.  1281.  429. 1301.  585.  1292.   456.  1208.   656.  1333.
6     6  279.   130.  1281.  429. 1301.  585.  1208.   656.  1224.   649.  1333.
7     7   11.5   20.7 1301.  585. 1281.  429.  1292.   456.  1208.   656.  1333.
8     8  170.    17.6 1301.  585. 1281.  429.  1208.   656.  1224.   649.  1333.
# ℹ 4 more variables: tail0y <dbl>, tail1x <dbl>, tail1y <dbl>, mid_dist <dbl>
Frame 5100: score0=4.806, score1=7.471, accept0=TRUE, accept1=TRUE
Frame 5200: score0=3.387, score1=3.099, accept0=TRUE, accept1=TRUE
Frame 5300: score0=4.968, score1=9.529, accept0=TRUE, accept1=TRUE
Frame 5400: score0=2.897, score1=3.055, accept0=TRUE, accept1=TRUE
Frame 5500: score0=3.178, score1=3.049, accept0=TRUE, accept1=TRUE
Frame 5600: score0=17.49, score1=14.65, accept0=TRUE, accept1=TRUE
Frame 5700: score0=10.54, score1=26.26, accept0=TRUE, accept1=FALSE
Frame 5800: score0=3.261, score1=3.008, accept0=TRUE, accept1=TRUE
Frame 5900: score0=3.267, score1=2.917, accept0=TRUE, accept1=TRUE
Frame 6000: score0=3.536, score1=3.976, accept0=TRUE, accept1=TRUE
Frame 6100: score0=3.334, score1=4.273, accept0=TRUE, accept1=TRUE
Frame 6200: score0=Inf, score1=Inf, accept0=FALSE, accept1=FALSE
Frame 6300: score0=3.292, score1=3.079, accept0=TRUE, accept1=TRUE
Frame 6400: score0=11.02, score1=12, accept0=TRUE, accept1=TRUE
Frame 6500: score0=3.005, score1=3, accept0=TRUE, accept1=TRUE
Frame 6600: score0=3.702, score1=3.805, accept0=TRUE, accept1=TRUE
Frame 6700: score0=3.013, score1=27.71, accept0=TRUE, accept1=FALSE
Frame 6800: score0=71.54, score1=3.236, accept0=FALSE, accept1=TRUE
Frame 6900: score0=8.717, score1=16, accept0=TRUE, accept1=TRUE
Frame 7000: score0=3.701, score1=3.97, accept0=TRUE, accept1=TRUE
Frame 7100: score0=3.111, score1=4.856, accept0=TRUE, accept1=TRUE
Frame 7200: score0=Inf, score1=Inf, accept0=FALSE, accept1=FALSE
Frame 7300: score0=68.33, score1=3.261, accept0=FALSE, accept1=TRUE
Frame 7400: score0=11.89, score1=13.06, accept0=TRUE, accept1=TRUE
Frame 7500: score0=3.327, score1=3.301, accept0=TRUE, accept1=TRUE
Frame 7600: score0=12.01, score1=7.762, accept0=TRUE, accept1=TRUE
Frame 7700: score0=8.767, score1=122.1, accept0=TRUE, accept1=FALSE
Frame 7800: score0=2.92, score1=3.28, accept0=TRUE, accept1=TRUE
Frame 7900: score0=12.19, score1=14.66, accept0=TRUE, accept1=TRUE
Frame 8000: score0=3.135, score1=3.196, accept0=TRUE, accept1=TRUE
Frame 8100: score0=3.88, score1=3.791, accept0=TRUE, accept1=TRUE
Frame 8200: score0=3.544, score1=6.323, accept0=TRUE, accept1=TRUE
Frame 8300: score0=12.41, score1=3.337, accept0=TRUE, accept1=TRUE
Frame 8400: score0=3.885, score1=5.379, accept0=TRUE, accept1=TRUE
Frame 8500: score0=5.614, score1=3.04, accept0=TRUE, accept1=TRUE
Frame 8600: score0=Inf, score1=Inf, accept0=FALSE, accept1=FALSE
Frame 8700: score0=3.83, score1=3.021, accept0=TRUE, accept1=TRUE
Frame 8800: score0=3.744, score1=3.133, accept0=TRUE, accept1=TRUE
Frame 8900: score0=3.646, score1=10.92, accept0=TRUE, accept1=TRUE
Frame 9000: score0=11.73, score1=11.89, accept0=TRUE, accept1=TRUE
Frame 9100: score0=12.85, score1=11.63, accept0=TRUE, accept1=TRUE
Frame 9200: score0=2.939, score1=2.919, accept0=TRUE, accept1=TRUE
Frame 9300: score0=117.2, score1=4.158, accept0=FALSE, accept1=TRUE
Frame 9400: score0=3.263, score1=2.967, accept0=TRUE, accept1=TRUE
Frame 9500: score0=2.954, score1=3.768, accept0=TRUE, accept1=TRUE
Frame 9600: score0=2.906, score1=2.267, accept0=TRUE, accept1=TRUE
Frame 9700: score0=2.915, score1=48.36, accept0=TRUE, accept1=FALSE
Frame 9800: score0=11.74, score1=12.15, accept0=TRUE, accept1=TRUE
Frame 9900: score0=3.6, score1=4.195, accept0=TRUE, accept1=TRUE
Frame 10000: score0=6.919, score1=8.92, accept0=TRUE, accept1=TRUE
Frame 10100: score0=4.124, score1=9.153, accept0=TRUE, accept1=TRUE
Frame 10200: score0=3.101, score1=4.182, accept0=TRUE, accept1=TRUE
Frame 10300: score0=3.415, score1=2.944, accept0=TRUE, accept1=TRUE
Frame 10400: score0=2.96, score1=2.976, accept0=TRUE, accept1=TRUE
Frame 10500: score0=3.207, score1=3.117, accept0=TRUE, accept1=TRUE
Frame 10600: score0=3.001, score1=3.151, accept0=TRUE, accept1=TRUE
Frame 10700: score0=3.009, score1=2.984, accept0=TRUE, accept1=TRUE
Frame 10800: score0=3.604, score1=2.908, accept0=TRUE, accept1=TRUE
Frame 10900: score0=4.488, score1=5.574, accept0=TRUE, accept1=TRUE
Frame 11000: score0=2.892, score1=2.904, accept0=TRUE, accept1=TRUE
Frame 11100: score0=3.755, score1=3.274, accept0=TRUE, accept1=TRUE
Frame 11200: score0=3.256, score1=3.202, accept0=TRUE, accept1=TRUE
Frame 11300: score0=3.179, score1=4.039, accept0=TRUE, accept1=TRUE
Frame 11400: score0=3.055, score1=3.234, accept0=TRUE, accept1=TRUE
Frame 11500: score0=3.234, score1=2.959, accept0=TRUE, accept1=TRUE
Frame 11600: score0=3.29, score1=37.03, accept0=TRUE, accept1=FALSE
Frame 11700: score0=11.15, score1=18.34, accept0=TRUE, accept1=TRUE
Frame 11800: score0=5.583, score1=2.916, accept0=TRUE, accept1=TRUE
Frame 11900: score0=12.04, score1=3.728, accept0=TRUE, accept1=TRUE
Frame 12000: score0=9.021, score1=9.622, accept0=TRUE, accept1=TRUE
Frame 12100: score0=9.031, score1=14.85, accept0=TRUE, accept1=TRUE
Frame 12200: score0=4.509, score1=3, accept0=TRUE, accept1=TRUE
Frame 12300: score0=2.977, score1=3.038, accept0=TRUE, accept1=TRUE
Frame 12400: score0=3.207, score1=2.973, accept0=TRUE, accept1=TRUE
Frame 12500: score0=2.952, score1=4.495, accept0=TRUE, accept1=TRUE
Frame 12600: score0=14.59, score1=4.546, accept0=TRUE, accept1=TRUE
Frame 12700: score0=3.242, score1=52.13, accept0=TRUE, accept1=FALSE
Frame 12800: score0=2.897, score1=3.143, accept0=TRUE, accept1=TRUE
Frame 12900: score0=3.927, score1=3.119, accept0=TRUE, accept1=TRUE
Frame 13000: score0=4.323, score1=3.156, accept0=TRUE, accept1=TRUE
Frame 13100: score0=3.436, score1=4.791, accept0=TRUE, accept1=TRUE
Frame 13200: score0=3.294, score1=2.911, accept0=TRUE, accept1=TRUE
Frame 13300: score0=250.1, score1=16.54, accept0=FALSE, accept1=TRUE
Frame 13400: score0=3.132, score1=3.284, accept0=TRUE, accept1=TRUE
Frame 13500: score0=6.151, score1=11.91, accept0=TRUE, accept1=TRUE
Frame 13600: score0=11.07, score1=3.489, accept0=TRUE, accept1=TRUE
Frame 13700: score0=3.162, score1=4.101, accept0=TRUE, accept1=TRUE
Frame 13800: score0=2.905, score1=4.075, accept0=TRUE, accept1=TRUE
Frame 13900: score0=9.014, score1=5.361, accept0=TRUE, accept1=TRUE
Frame 14000: score0=3.012, score1=2.915, accept0=TRUE, accept1=TRUE
Frame 14100: score0=613.6, score1=10.15, accept0=FALSE, accept1=TRUE
Frame 14200: score0=Inf, score1=Inf, accept0=FALSE, accept1=FALSE
Frame 14300: score0=3.486, score1=3.167, accept0=TRUE, accept1=TRUE
Frame 14400: score0=4.064, score1=15.05, accept0=TRUE, accept1=TRUE
Frame 14500: score0=3.013, score1=4.128, accept0=TRUE, accept1=TRUE
Frame 14600: score0=12.23, score1=56.6, accept0=TRUE, accept1=FALSE
Frame 14700: score0=46.35, score1=16.62, accept0=FALSE, accept1=TRUE
Frame 14800: score0=3.64, score1=3.902, accept0=TRUE, accept1=TRUE
Frame 14900: score0=5.912, score1=2.893, accept0=TRUE, accept1=TRUE
Frame 15000: score0=4.542, score1=3.79, accept0=TRUE, accept1=TRUE
Frame 15100: score0=3.015, score1=7.398, accept0=TRUE, accept1=TRUE
Frame 15200: score0=3.62, score1=2.937, accept0=TRUE, accept1=TRUE
Frame 15300: score0=3.37, score1=2.959, accept0=TRUE, accept1=TRUE
Frame 15400: score0=3.154, score1=2.893, accept0=TRUE, accept1=TRUE
Frame 15500: score0=3.53, score1=3.042, accept0=TRUE, accept1=TRUE
Frame 15600: score0=7.482, score1=2.907, accept0=TRUE, accept1=TRUE
Frame 15700: score0=10.59, score1=3.815, accept0=TRUE, accept1=TRUE
Frame 15800: score0=5.467, score1=3.215, accept0=TRUE, accept1=TRUE
Frame 15900: score0=3.638, score1=3.407, accept0=TRUE, accept1=TRUE
Frame 16000: score0=3.026, score1=3.123, accept0=TRUE, accept1=TRUE
Frame 16100: score0=3.222, score1=3.529, accept0=TRUE, accept1=TRUE
Frame 16200: score0=3.415, score1=5.462, accept0=TRUE, accept1=TRUE
Frame 16300: score0=5.91, score1=3.72, accept0=TRUE, accept1=TRUE
Frame 16400: score0=17.8, score1=20.06, accept0=TRUE, accept1=FALSE
Frame 16500: score0=3.112, score1=2.964, accept0=TRUE, accept1=TRUE
Frame 16600: score0=5.54, score1=9.51, accept0=TRUE, accept1=TRUE
Frame 16700: score0=4.361, score1=3.082, accept0=TRUE, accept1=TRUE
Frame 16800: score0=6.951, score1=7.217, accept0=TRUE, accept1=TRUE
Frame 16900: score0=6.51, score1=13.32, accept0=TRUE, accept1=TRUE
Frame 17000: score0=10.32, score1=17.61, accept0=TRUE, accept1=TRUE
Frame 17100: score0=3.204, score1=2.962, accept0=TRUE, accept1=TRUE
Frame 17200: score0=10.71, score1=3.652, accept0=TRUE, accept1=TRUE
Frame 17300: score0=4.279, score1=3.804, accept0=TRUE, accept1=TRUE
Frame 17400: score0=3.713, score1=16.01, accept0=TRUE, accept1=TRUE
Frame 17500: score0=17.97, score1=7.546, accept0=TRUE, accept1=TRUE
Frame 17600: score0=3.1, score1=11.43, accept0=TRUE, accept1=TRUE
Frame 17700: score0=2.977, score1=4.33, accept0=TRUE, accept1=TRUE
Frame 17800: score0=3.811, score1=3.091, accept0=TRUE, accept1=TRUE
Frame 17900: score0=3.112, score1=3.471, accept0=TRUE, accept1=TRUE
Frame 18000: score0=3.04, score1=3.109, accept0=TRUE, accept1=TRUE
library(dplyr)
library(ggplot2)

# Helper function to reshape and plot with part-specific reliability
plot_part_axis <- function(res, part, axis = c("x","y"), frame_range = c(5000, 6000)) {
  axis <- match.arg(axis)
  
  # column names
  col0 <- paste0("inst0_", part, "_", axis)
  col1 <- paste0("inst1_", part, "_", axis)
  
  # part-specific reliabilities
  rel0 <- paste0("reliability_", part, "0")
  rel1 <- paste0("reliability_", part, "1")
  
  # reshape for plotting
  df0 <- res %>%
    filter(frame >= frame_range[1], frame <= frame_range[2]) %>%
    transmute(frame, value = .data[[col0]], reliable = .data[[rel0]], instance = "inst0")
  
  df1 <- res %>%
    filter(frame >= frame_range[1], frame <= frame_range[2]) %>%
    transmute(frame, value = .data[[col1]], reliable = .data[[rel1]], instance = "inst1")
  
  df_all <- bind_rows(df0, df1)
  
  # plot
  ggplot(df_all, aes(x = frame, y = value, group = instance)) +
    geom_line(aes(color = instance), linewidth = 0.5) +
    geom_point(data = df_all %>% filter(!reliable), aes(x = frame, y = value),
               color = "black", size = 0.8) +
    scale_color_manual(values = c("inst0" = "red", "inst1" = "blue")) +
    labs(title = paste(part, toupper(axis), "Frames", frame_range[1], "-", frame_range[2])) +
    theme_minimal() +
    theme(plot.title = element_text(hjust = 0.5))
}

# --- Plot all parts and axes ---
parts <- c("mid", "nose", "tail")
axes  <- c("x", "y")

for (p in parts) {
  for (a in axes) {
    print(plot_part_axis(res, p, a))
  }
}

# --- per-part summary ---
reliability_summary <- function(res, part) {
  tibble(
    part = part,
    inst0_reliable   = sum(res[[paste0("reliability_", part, "0")]], na.rm = TRUE),
    inst0_unreliable = sum(!res[[paste0("reliability_", part, "0")]], na.rm = TRUE),
    inst1_reliable   = sum(res[[paste0("reliability_", part, "1")]], na.rm = TRUE),
    inst1_unreliable = sum(!res[[paste0("reliability_", part, "1")]], na.rm = TRUE)
  )
}

reliability_tables <- bind_rows(lapply(parts, reliability_summary, res = res))

# --- overall, by instance: frames with ALL three parts reliable vs. frames with NONE reliable ---
overall_reliability <- tibble(
  part = "ALL",
  inst0_reliable   = sum(res$reliability_mid0 & res$reliability_nose0 & res$reliability_tail0, na.rm = TRUE),
  inst0_unreliable = sum(!res$reliability_mid0 & !res$reliability_nose0 & !res$reliability_tail0, na.rm = TRUE),
  inst1_reliable   = sum(res$reliability_mid1 & res$reliability_nose1 & res$reliability_tail1, na.rm = TRUE),
  inst1_unreliable = sum(!res$reliability_mid1 & !res$reliability_nose1 & !res$reliability_tail1, na.rm = TRUE)
)

# combine into one table
reliability_tables <- bind_rows(reliability_tables, overall_reliability)

print(reliability_tables)
# A tibble: 4 × 5
  part  inst0_reliable inst0_unreliable inst1_reliable inst1_unreliable
  <chr>          <int>            <int>          <int>            <int>
1 mid            13836             4221          13741             4316
2 nose           11663             6394          10501             7556
3 tail           13601             4456          13583             4474
4 ALL             8939             2080           8013             1888
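To put the counts in the table above on a common scale, a small sketch that converts them into shares of all frames. The counts are copied from the printed output; note that for the `ALL` row the two columns mean "all three parts reliable" and "no part reliable", so that row does not sum to the total.

```r
# Counts copied from the printed reliability_tables output above
counts <- data.frame(
  part             = c("mid", "nose", "tail", "ALL"),
  inst0_reliable   = c(13836, 11663, 13601, 8939),
  inst0_unreliable = c(4221, 6394, 4456, 2080),
  inst1_reliable   = c(13741, 10501, 13583, 8013),
  inst1_unreliable = c(4316, 7556, 4474, 1888)
)

# For a single part every frame is either reliable or unreliable,
# so each per-part row sums to the total number of frames
n_frames <- counts$inst0_reliable[1] + counts$inst0_unreliable[1]
stopifnot(all(counts$inst1_reliable[1:3] + counts$inst1_unreliable[1:3] == n_frames))

# Share of all frames flagged reliable (for ALL: all three parts reliable)
counts$inst0_frac <- round(counts$inst0_reliable / n_frames, 3)
counts$inst1_frac <- round(counts$inst1_reliable / n_frames, 3)
print(counts[, c("part", "inst0_frac", "inst1_frac")])
```

The per-part shares come out at about 0.75-0.77 for the mid-center and tail-base and 0.58-0.65 for the nose, while all three parts are reliable together in under half of the frames (0.495 and 0.444).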
# Create reliable-only data frames for each body part × instance
nose0_reliable_df <- res %>% filter(reliability_nose0)
mid0_reliable_df  <- res %>% filter(reliability_mid0)
tail0_reliable_df <- res %>% filter(reliability_tail0)

nose1_reliable_df <- res %>% filter(reliability_nose1)
mid1_reliable_df  <- res %>% filter(reliability_mid1)
tail1_reliable_df <- res %>% filter(reliability_tail1)

After the first phase, roughly 3/4 of the frames are confirmed reliable for the mid-center and tail-base (75-77% for both instances), but fewer for the nose (65% for instance 0 and 58% for instance 1), and only 8,939 (instance 0) and 8,013 (instance 1) of the 18,057 frames have all three parts reliable. Among the remaining frames, some can easily be confirmed or interpolated from nearby frames, while those in longer gaps will need more advanced techniques, such as Kalman filtering and posterior densities, to fill in. A forward Kalman filter followed by a backward smoothing pass gives a Rauch-Tung-Striebel (RTS) smoother, which can estimate the positions in the unreliable frames. To maintain conjugacy, the prior for the skeleton would need to combine the conditional Gaussian distribution implied by the skeleton with the Kalman filter prediction from previous time points. The likelihood would then be based on the other body parts observed in the current frame.
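A minimal sketch of the forward Kalman filter plus backward RTS pass for one coordinate of one body part. It assumes a constant-velocity state model, treats x and y independently, and does not yet include the skeleton prior; unreliable frames are passed in as `NA` and receive only the prediction step. The function name `rts_smooth_1d` and the noise levels `q` (process) and `r` (measurement variance, in squared pixels) are hypothetical and would need tuning.

```r
# Hypothetical sketch: constant-velocity Kalman filter + RTS smoother for one
# coordinate series z, where NA marks an unreliable (missing) frame.
rts_smooth_1d <- function(z, q = 1, r = 4) {
  stopifnot(any(!is.na(z)))
  n <- length(z)
  A <- matrix(c(1, 0, 1, 1), 2, 2)             # [1 1; 0 1]: position += velocity
  H <- matrix(c(1, 0), 1, 2)                   # only the position is observed
  Q <- q * matrix(c(1/3, 1/2, 1/2, 1), 2, 2)   # white-noise acceleration, dt = 1
  x_pred <- x_filt <- matrix(0, 2, n)
  P_pred <- P_filt <- array(0, c(2, 2, n))
  x <- c(z[which(!is.na(z))[1]], 0)            # start at first observed position
  P <- diag(1e4, 2)                            # vague initial covariance
  for (k in seq_len(n)) {
    if (k > 1) {                               # predict step
      x <- A %*% x
      P <- A %*% P %*% t(A) + Q
    }
    x_pred[, k] <- x
    P_pred[, , k] <- P
    if (!is.na(z[k])) {                        # update step on reliable frames only
      S <- as.numeric(H %*% P %*% t(H)) + r
      K <- P %*% t(H) / S                      # Kalman gain (2 x 1)
      x <- x + K * (z[k] - as.numeric(H %*% x))
      P <- (diag(2) - K %*% H) %*% P
    }
    x_filt[, k] <- x
    P_filt[, , k] <- P
  }
  x_s <- x_filt                                # backward RTS pass
  for (k in rev(seq_len(n - 1))) {
    G <- P_filt[, , k] %*% t(A) %*% solve(P_pred[, , k + 1])  # smoother gain
    x_s[, k] <- x_filt[, k] + G %*% (x_s[, k + 1] - x_pred[, k + 1])
  }
  x_s[1, ]                                     # smoothed positions
}

# Toy series with a two-frame gap
z <- c(10, 11, 12, NA, NA, 15, 16)
s <- rts_smooth_1d(z)
print(round(s, 2))
```

On the real data, the input could be something like `ifelse(res$reliability_mid0, res$inst0_mid_x, NA)`; the skeleton-based conditional prior would then enter as an additional Gaussian term in the update step.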

Animated plot of frames 6000-8000 for the two mice, with unreliable points colored black:

res_1000 <- res %>% filter(frame >= 6000, frame <= 8000)

# --- Pivot coordinates for instance 0 ---
coords0 <- res_1000 %>%
  select(frame,
         inst0_nose_x, inst0_nose_y,
         inst0_mid_x,  inst0_mid_y,
         inst0_tail_x, inst0_tail_y) %>%
  pivot_longer(
    cols = -frame,
    names_to = c("body_part", ".value"),
    names_pattern = "inst0_(nose|mid|tail)_(x|y)"
  ) %>%
  mutate(body_part = recode(body_part,
                            nose = "Nose",
                            mid  = "Mid-center",
                            tail = "Tail-base"),
         instance = "Instance 0")

# --- Pivot reliabilities for instance 0 ---
reliab0 <- res_1000 %>%
  select(frame,
         reliability_nose0,
         reliability_mid0,
         reliability_tail0) %>%
  pivot_longer(
    cols = -frame,
    names_to = "body_part",
    names_pattern = "reliability_(nose|mid|tail)0"
  ) %>%
  rename(reliable = value) %>%
  mutate(body_part = recode(body_part,
                            nose = "Nose",
                            mid  = "Mid-center",
                            tail = "Tail-base"),
         instance = "Instance 0")

# --- Pivot coordinates for instance 1 ---
coords1 <- res_1000 %>%
  select(frame,
         inst1_nose_x, inst1_nose_y,
         inst1_mid_x,  inst1_mid_y,
         inst1_tail_x, inst1_tail_y) %>%
  pivot_longer(
    cols = -frame,
    names_to = c("body_part", ".value"),
    names_pattern = "inst1_(nose|mid|tail)_(x|y)"
  ) %>%
  mutate(body_part = recode(body_part,
                            nose = "Nose",
                            mid  = "Mid-center",
                            tail = "Tail-base"),
         instance = "Instance 1")

# --- Pivot reliabilities for instance 1 ---
reliab1 <- res_1000 %>%
  select(frame,
         reliability_nose1,
         reliability_mid1,
         reliability_tail1) %>%
  pivot_longer(
    cols = -frame,
    names_to = "body_part",
    names_pattern = "reliability_(nose|mid|tail)1"
  ) %>%
  rename(reliable = value) %>%
  mutate(body_part = recode(body_part,
                            nose = "Nose",
                            mid  = "Mid-center",
                            tail = "Tail-base"),
         instance = "Instance 1")

# --- Combine instances (a new name keeps the raw `coords` table intact) ---
coords_anim <- bind_rows(coords0, coords1)
reliab <- bind_rows(reliab0, reliab1)
df_anim <- left_join(coords_anim, reliab,
                     by = c("frame", "body_part", "instance"))

# --- Plot ---
fig <- plot_ly(df_anim %>% filter(reliable, instance == "Instance 0"),
               x = ~x, y = ~y, frame = ~frame,
               color = ~body_part,
               colors = c("Nose" = "red",
                          "Mid-center" = "blue",
                          "Tail-base" = "green"),
               type = 'scatter', mode = 'markers+lines',
               line = list(width = 2),
               marker = list(symbol = "circle", size = 8)) %>%
  add_trace(data = df_anim %>% filter(!reliable, instance == "Instance 0"),
            x = ~x, y = ~y, frame = ~frame,
            type = 'scatter', mode = 'markers',
            marker = list(symbol = "circle", size = 8, color = "black"),
            inherit = FALSE,
            showlegend = FALSE) %>%
  add_trace(data = df_anim %>% filter(reliable, instance == "Instance 1"),
            x = ~x, y = ~y, frame = ~frame,
            color = ~body_part,
            colors = c("Nose" = "red",
                       "Mid-center" = "blue",
                       "Tail-base" = "green"),
            type = 'scatter', mode = 'markers+lines',
            line = list(width = 2),
            marker = list(symbol = "triangle-up", size = 8),
            inherit = FALSE) %>%
  add_trace(data = df_anim %>% filter(!reliable, instance == "Instance 1"),
            x = ~x, y = ~y, frame = ~frame,
            type = 'scatter', mode = 'markers',
            marker = list(symbol = "triangle-up", size = 8, color = "black"),
            inherit = FALSE,
            showlegend = FALSE) %>%
  layout(
    title = "Instances 0 (circle) and 1 (triangle) with reliability (black = unreliable, dots only)",
    xaxis = list(title = "X"),
    yaxis = list(title = "Y", scaleanchor = "x")
  ) %>%
  animation_opts(
    frame = 40,  # ms per frame
    transition = 0,
    redraw = FALSE
  )

fig
# Full reliability counts for the mid-center (instance 1, then instance 0)
table(res$reliability_mid1)

FALSE  TRUE 
 4316 13741 
table(res$reliability_mid0)

FALSE  TRUE 
 4221 13836 

Main problems:

The first method achieves a high percentage of reliable frames for each body part of each instance, but it treats the two instances asymmetrically, which leads to wrong acceptances.

The second method computes a negative log-likelihood score for each combination of instances over all body parts at once. This is too strict: the percentage of reliable body parts is too low, and in many cases a single unreliable body part causes the rest to be rejected as well. Still, it seems the most promising so far. I can try to increase the acceptance rate, and if there is no better choice, I will use it for the next step of Kalman filtering.
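A minimal sketch of this kind of joint scoring, assuming the Gaussian distance prior from the EDA section: for each way of splitting the nose, mid-center and tail-base detections between two skeletons, sum the Gaussian negative log-likelihoods of each skeleton's three pairwise distances and keep the lowest total. The prior means and standard deviations are made-up placeholders (the real ones are the MLE estimates from the labelled data), and `skeleton_nll` / `best_assignment` are hypothetical names. The acceptance threshold used above is not stated, so none is applied here.

```r
# Placeholder prior (pixels); replace with the MLE estimates from labelled data
prior_mean <- c(dtn = 80, dtm = 40, dnm = 40)
prior_sd   <- c(dtn = 10, dtm = 6,  dnm = 6)

# Gaussian NLL of one skeleton from its three pairwise distances
skeleton_nll <- function(nose, mid, tail) {
  d <- c(dtn = sqrt(sum((nose - tail)^2)),
         dtm = sqrt(sum((mid - tail)^2)),
         dnm = sqrt(sum((mid - nose)^2)))
  if (anyNA(d)) return(Inf)                    # a missing part cannot be scored
  -sum(dnorm(d, mean = prior_mean[names(d)], sd = prior_sd[names(d)], log = TRUE))
}

# pts: two detected instances, each list(nose =, mid =, tail =) of c(x, y)
best_assignment <- function(pts) {
  # Instance that skeleton A takes each part from; skeleton B takes the other.
  # Fixing nose = 1 drops the mirrored A/B duplicates, leaving 4 combinations.
  combos <- expand.grid(nose = 1L, mid = 1:2, tail = 1:2)
  scores <- t(apply(combos, 1, function(ix) {
    b <- 3L - ix
    c(A = skeleton_nll(pts[[ix[["nose"]]]]$nose, pts[[ix[["mid"]]]]$mid, pts[[ix[["tail"]]]]$tail),
      B = skeleton_nll(pts[[b[["nose"]]]]$nose, pts[[b[["mid"]]]]$mid, pts[[b[["tail"]]]]$tail))
  }))
  i <- which.min(rowSums(scores))
  list(assignment = combos[i, ], scores = scores[i, ])
}

# Two straight mice; the detector swapped their mid-center labels
pts <- list(
  list(nose = c(80, 0),   mid = c(40, 200), tail = c(0, 0)),
  list(nose = c(80, 200), mid = c(40, 0),   tail = c(0, 200))
)
fit <- best_assignment(pts)
print(fit$assignment)   # skeleton A: nose 1, mid 2, tail 1
```

Because one badly placed part inflates its skeleton's whole score, this sketch also reproduces the weakness described above: a single unreliable part drags down the other two.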

The third method uses Hungarian matching of instances to obtain a reliable track for a single anchor body part (e.g. the mid-center), and it generalizes to more instances and body parts. The biggest challenge is enforcing the skeleton structure in this method: it currently handles each body part separately, which leads to poor skeleton consistency.
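A minimal sketch of the frame-to-frame matching step for one anchor part, assuming the cost is the Euclidean distance between each instance's position in the previous frame and each detection in the current frame. It finds the optimal assignment (the problem the Hungarian algorithm solves) by exhaustive search over permutations, which is exact and cheap for the few animals in a frame; for larger problems, `clue::solve_LSAP()` solves the same problem with the Hungarian method. `perms` and `match_instances` are hypothetical helpers.

```r
# All permutations of 1:n, one per row (fine for a handful of instances)
perms <- function(n) {
  if (n == 1) return(matrix(1L, 1, 1))
  p <- perms(n - 1)
  do.call(rbind, lapply(seq_len(n), function(i)
    cbind(rep(i, nrow(p)), matrix(setdiff(seq_len(n), i)[p], nrow = nrow(p)))))
}

# prev, curr: n x 2 matrices of anchor (x, y), one row per instance.
# Returns, for each previous instance, the row of curr it is matched to.
match_instances <- function(prev, curr) {
  cost <- sqrt(outer(prev[, 1], curr[, 1], "-")^2 +
               outer(prev[, 2], curr[, 2], "-")^2)   # cost[i, j] = distance
  P <- perms(nrow(prev))
  total <- apply(P, 1, function(p) sum(cost[cbind(seq_along(p), p)]))
  P[which.min(total), ]                              # minimum total distance
}

prev <- rbind(c(100, 100), c(300, 120))   # mid-center positions in frame t-1
curr <- rbind(c(305, 118), c(102, 99))    # detector labels swapped in frame t
print(match_instances(prev, curr))        # 2 1: prev 1 -> curr 2, prev 2 -> curr 1
```

Since each anchor part is matched on its own, this sketch shares the limitation noted above: nothing ties the nose or tail-base to the matched mid-center.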